# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Experimental] Live Music API client."""

import contextlib
import json
import logging
from typing import AsyncIterator

from . import _api_module
from . import _common
from . import _live_converters as live_converters
from . import _transformers as t
from . import types
from ._api_client import BaseApiClient
from ._common import set_value_by_path as setv


# Prefer the client from `websockets.asyncio.client`; when the installed
# websockets package does not provide that module, fall back to the legacy
# `websockets.client` module, which exposes the same two names.
try:
  from websockets.asyncio.client import ClientConnection
  from websockets.asyncio.client import connect
except ModuleNotFoundError:
  from websockets.client import ClientConnection  # type: ignore
  from websockets.client import connect  # type: ignore

# Module-level logger for the live music client.
logger = logging.getLogger('google_genai.live_music')


class AsyncMusicSession:
  """[Experimental] AsyncMusicSession.

  Wraps an already-open websocket to the live music server. Setter and
  playback methods send JSON messages over the socket; `receive` yields the
  parsed server messages.
  """

  def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
    self._api_client = api_client
    self._ws = websocket

  def _raise_if_vertexai(self) -> None:
    """Raises NotImplementedError when the client is configured for Vertex AI.

    Only the Gemini Developer API (mldev) path is implemented for live music.
    """
    if self._api_client.vertexai:
      raise NotImplementedError(
          'Live music generation is not supported in Vertex AI.'
      )

  async def set_weighted_prompts(
      self, prompts: list[types.WeightedPrompt]
  ) -> None:
    """Sends the weighted prompts that steer music generation.

    Args:
      prompts: The prompts to send. Each prompt is converted to a plain dict
        with `_common.convert_to_dict(..., convert_keys=True)`.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._raise_if_vertexai()
    client_content = {
        'weightedPrompts': [
            _common.convert_to_dict(prompt, convert_keys=True)
            for prompt in prompts
        ]
    }
    await self._ws.send(json.dumps({'clientContent': client_content}))

  async def set_music_generation_config(
      self, config: types.LiveMusicGenerationConfig
  ) -> None:
    """Sends a new music generation config to the server.

    Args:
      config: The config to send, converted with
        `_common.convert_to_dict(..., convert_keys=True)`.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._raise_if_vertexai()
    config_dict = _common.convert_to_dict(config, convert_keys=True)
    await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))

  async def _send_control_signal(
      self, playback_control: types.LiveMusicPlaybackControl
  ) -> None:
    """Sends a single playback control message (its enum `.value`).

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._raise_if_vertexai()
    await self._ws.send(
        json.dumps({'playbackControl': playback_control.value})
    )

  async def play(self) -> None:
    """Sends playback signal to start the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PLAY)

  async def pause(self) -> None:
    """Sends a playback signal to pause the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PAUSE)

  async def stop(self) -> None:
    """Sends a playback signal to stop the music stream.

    Resets the music generation context while retaining the current config.
    """
    return await self._send_control_signal(types.LiveMusicPlaybackControl.STOP)

  async def reset_context(self) -> None:
    """Reset the context (prompts retained) without stopping the music generation."""
    return await self._send_control_signal(
        types.LiveMusicPlaybackControl.RESET_CONTEXT
    )

  async def receive(self) -> AsyncIterator[types.LiveMusicServerMessage]:
    """Receive model responses from the server.

    Yields:
      The audio chunks from the server.
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    # NOTE(review): a parsed message object is presumably always truthy, so
    # iteration likely ends only when `_receive` raises (e.g. when the
    # connection closes) -- confirm.
    while result := await self._receive():
      yield result

  async def _receive(self) -> types.LiveMusicServerMessage:
    """Receives and parses one server message.

    Returns:
      The parsed message. An empty frame is treated as an empty JSON object.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
      ValueError: If the received frame is not valid JSON.
    """
    # Check before reading so an unsupported client does not consume a frame.
    self._raise_if_vertexai()
    parameter_model = types.LiveMusicServerMessage()
    try:
      raw_response = await self._ws.recv(decode=False)
    except TypeError:
      # Older websockets `recv()` does not accept the `decode` argument.
      raw_response = await self._ws.recv()  # type: ignore[assignment]
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError as e:
        raise ValueError(f'Failed to parse response: {raw_response!r}') from e
    else:
      response = {}

    return types.LiveMusicServerMessage._from_response(
        response=response, kwargs=parameter_model.model_dump()
    )

  async def close(self) -> None:
    """Closes the bi-directional stream and terminates the session."""
    await self._ws.close()


class AsyncLiveMusic(_api_module.BaseModule):
  """[Experimental] Live music module.

  Live music can be accessed via `client.aio.live.music`.
  """

  @_common.experimental_warning(
      'Realtime music generation is experimental and may change in future versions.'
  )
  @contextlib.asynccontextmanager
  async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
    """[Experimental] Connect to the live music server.

    Opens a websocket to the BidiGenerateMusic endpoint, sends the setup
    request for `model`, logs the server's first reply, and yields an
    `AsyncMusicSession` bound to the connection. The websocket is closed
    when the context exits.

    Args:
      model: The model name, normalized with `t.t_model`.

    Yields:
      An `AsyncMusicSession` wrapping the open websocket.

    Raises:
      NotImplementedError: If the API client has no API key (only the
        API-key / mldev path is implemented).
    """
    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)

    if not self._api_client.api_key:
      raise NotImplementedError(
          'Live music generation is not supported in Vertex AI.'
      )

    api_key = self._api_client.api_key
    version = self._api_client._http_options.api_version
    uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
    headers = self._api_client._http_options.headers

    # Only mldev supported
    request_dict = _common.convert_to_dict(
        live_converters._LiveMusicConnectParameters_to_mldev(
            from_object=types.LiveMusicConnectParameters(
                model=transformed_model,
            ).model_dump(exclude_none=True)
        )
    )
    setv(request_dict, ['setup', 'model'], transformed_model)
    request = json.dumps(request_dict)

    async with contextlib.AsyncExitStack() as stack:
      try:
        ws = await stack.enter_async_context(
            connect(uri, additional_headers=headers)
        )
      except TypeError:
        # Older websockets clients name this parameter `extra_headers`.
        ws = await stack.enter_async_context(
            connect(uri, extra_headers=headers)
        )

      await ws.send(request)
      try:
        setup_response = await ws.recv(decode=False)
      except TypeError:
        # Older websockets `recv()` does not accept the `decode` argument.
        setup_response = await ws.recv()  # type: ignore[assignment]
      logger.info(setup_response)

      # The yield sits outside every `try`, so exceptions raised inside the
      # caller's `async with` body propagate unchanged instead of triggering
      # a reconnect (which would yield twice and break the context manager).
      yield AsyncMusicSession(api_client=self._api_client, websocket=ws)
